# JAX python snippets.

# Insert the conventional `jax.numpy` import under the `jnp` alias.
# The blank line before `endsnippet` is part of the expansion, so a trailing
# empty line is inserted after the import.
snippet jnp  "Import jnp"
import jax.numpy as jnp

endsnippet
# Split a JAX PRNG key in place: the first tab stop (default name `rng`) is
# rebound to one half of the split and `key` receives the other half for
# immediate use. `$1` mirrors the first tab stop, so editing the name updates
# both occurrences.
snippet jax.rng "Split a JAX PRNG key"
${1:rng}, key = jax.random.split($1)
endsnippet
# Decorator that jit-compiles the following function, with a tab stop for the
# static argument name passed to `static_argnames`.
# The expansion assumes `functools` and `jax` are imported in the target file.
snippet jax.jit "jax.jit decorator with static_argnames"
@functools.partial(jax.jit, static_argnames=('${1}',))
endsnippet

# Turn off XLA's up-front preallocation of a large fraction of GPU memory, so
# JAX allocates GPU memory on demand instead. This only takes effect if it runs
# before JAX initializes its GPU backend (i.e. before the first JAX
# computation). The expansion assumes `os` is already imported.
snippet jax_gpu_growth "Do not preallocate all GPU memory at once"
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
endsnippet
